# MIT License

# Copyright (c) 2024 dechin

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np

# Bohr radius expressed in Angstrom (multiply Bohr values by this to get Angstrom).
BOHR_TO_A = 0.5291772083
# Angstrom -> Bohr factor, derived from BOHR_TO_A so that converting to Bohr and
# back returns the original value to floating-point precision. The former
# literal (1.889716164632) disagreed with BOHR_TO_A at the 6th significant digit.
A_TO_BOHR = 1.0 / BOHR_TO_A

def read_out(file_name, idx=0, hat_lines=1, max_size=None, dlm=None):
    """ Read numeric columns from a text output file.

    Args:
        file_name: Path of the text file to read.
        idx: Column index (int or numpy integer) or a list/tuple of column
            indices to extract from every data line.
        hat_lines: Number of leading header lines to skip.
        max_size: If given, only the ``max_size`` lines following the header
            are considered. Blank lines inside that window are skipped.
        dlm: Column delimiter passed to ``str.split``; ``None`` splits on runs
            of whitespace.

    Returns:
        A 1-D float array (one value per data line) when ``idx`` is a single
        index, or a 2-D float array of shape (n_lines, len(idx)) when it is a
        list/tuple.

    Raises:
        ValueError: If ``idx`` is neither an integer nor a list/tuple, or a
            field cannot be converted to float.
        IndexError: If a requested column is missing on some line.
    """
    if isinstance(idx, (int, np.integer)):
        columns, single = [idx], True
    elif isinstance(idx, (list, tuple)):
        columns, single = list(idx), False
    else:
        raise ValueError("The data type of idx only support int and list.")

    with open(file_name, 'r') as file:
        lines = file.readlines()
    if max_size is None:
        lines = lines[hat_lines:]
    else:
        lines = lines[hat_lines: hat_lines + max_size]

    rows = []
    for line in lines:
        stripped = line.strip()
        if not stripped:
            # Tolerate blank lines (typically a trailing empty line) instead
            # of failing with IndexError / float('') on them.
            continue
        # Split once per line; str.split(None) is whitespace splitting.
        fields = stripped.split(dlm)
        rows.append([float(fields[c]) for c in columns])

    if single:
        return np.array([row[0] for row in rows])
    return np.array(rows)
    
def save_fes(file_name, Z):
    """ Write the FES values to a comma-delimited text file.

    Args:
        file_name: Destination path; an existing file is overwritten.
        Z: Array of FES values, written via ``numpy.savetxt`` with ',' as
            the column delimiter.
    """
    np.savetxt(fname=file_name, X=Z, delimiter=',')

def save_cube(file_name: str,
              origin_vec: np.ndarray,
              x_grids: int, x_shift: float,
              y_grids: int, y_shift: float,
              z_grids: int, z_shift: float,
              Z: np.ndarray,
              use_bohr: bool = True):
    """ Save the FES values into cube format.

    The file layout is:
    - two comment lines;
    - the origin line;
    - one line per grid axis;
    - a single atom record (number 1) placed at the grid centre;
    - the values, six per line, with every run of ``z_grids`` values
      starting on a new line.

    Args:
        file_name: Destination path; an existing file is overwritten.
        origin_vec: Grid origin (x, y, z), in the same length unit as the
            shifts. Not modified by this function.
        x_grids, y_grids, z_grids: Number of grid points along each axis.
        x_shift, y_shift, z_shift: Grid spacing along each axis.
        Z: FES values, either flat with z varying fastest, then y, then x,
            or any array whose C-order ravel has that layout (e.g. shape
            (x_grids, y_grids, z_grids)).
        use_bohr: If True, origin and spacings are multiplied by
            ``A_TO_BOHR`` (Angstrom -> Bohr) before writing.

    Returns:
        1 on success.

    Raises:
        ValueError: If ``Z`` holds fewer than x_grids*y_grids*z_grids values.
            The check happens before the file is opened.
    """
    total_grids = x_grids * y_grids * z_grids
    # Accept flat or (x, y, z)-shaped input; C-order ravel keeps z fastest.
    values = np.asarray(Z).ravel()
    if values.size < total_grids:
        raise ValueError(
            f"Z holds {values.size} values but the grid needs {total_grids}.")
    # Work on a float copy: an in-place ``*=`` would silently rescale the
    # caller's array (and fails outright for integer arrays).
    origin_vec = np.asarray(origin_vec, dtype=float)
    if use_bohr:
        origin_vec = origin_vec * A_TO_BOHR
        x_shift *= A_TO_BOHR
        y_shift *= A_TO_BOHR
        z_shift *= A_TO_BOHR
    center_x = x_grids * x_shift / 2 + origin_vec[0]
    center_y = y_grids * y_shift / 2 + origin_vec[1]
    center_z = z_grids * z_shift / 2 + origin_vec[2]

    with open(file_name, 'w') as file:
        file.write("Generated by CyFES\n")
        file.write(f"Total\t{total_grids}\tgrids\n")
        file.write(f"1\t{origin_vec[0]:.6g}\t{origin_vec[1]:.6g}\t{origin_vec[2]:.6g}\n")
        file.write(f"{x_grids}\t{x_shift:.6g}\t{0:.6g}\t{0:.6g}\n")
        file.write(f"{y_grids}\t{0:.6g}\t{y_shift:.6g}\t{0:.6g}\n")
        file.write(f"{z_grids}\t{0:.6g}\t{0:.6g}\t{z_shift:.6g}\n")
        file.write(f"{1}\t{1.0:.6f}\t{center_x:.6g}\t{center_y:.6g}\t{center_z:.6g}\n")
        # Each (x, y) column of z values is written as its own record,
        # wrapped at six values per line. Handling whole records avoids the
        # blank line an index-based loop produces when a record ends exactly
        # on a six-value boundary.
        for row_start in range(0, total_grids, max(z_grids, 1)):
            row = values[row_start:row_start + z_grids]
            for col in range(0, len(row), 6):
                file.write(''.join(f'{v:.6g}\t' for v in row[col:col + 6]) + '\n')
    return 1

def save_dat(dat_name: str,
             path: np.ndarray,
             Z: np.ndarray):
    """ Write points and their FES values as a tab-separated table.

    Args:
        dat_name: Destination path; an existing file is overwritten.
        path: Indexable sequence of points; the first three components of
            each point are written as the x, y, z columns.
        Z: FES value for each point; one output row is written per entry.
            ``path`` must have at least as many points.

    Returns:
        1 on success.
    """
    # The header contains a non-ASCII 'Å', so pin the encoding instead of
    # depending on the platform's locale default.
    with open(dat_name, 'w', encoding='utf-8') as file:
        file.write("#! FIELDS x(Å) y(Å) z(Å) fes(kJ/mol)\n")
        for i, value in enumerate(Z):
            point = path[i]
            file.write(f"{point[0]:.6g}\t{point[1]:.6g}\t{point[2]:.6g}\t{value:.6g}\n")
    return 1

def frange(x, y, jump):
    """ Yield x, x + jump, x + 2*jump, ... while the value is below y.

    A float-friendly counterpart of ``range``. Values are produced by
    repeated addition, so floating-point rounding can accumulate.

    Args:
        x: Start value (yielded first if it is below ``y``).
        y: Exclusive upper bound.
        jump: Positive step.

    Raises:
        ValueError: If ``jump`` is not positive. Because this is a generator,
            the error is raised when iteration starts.
    """
    if jump <= 0:
        # A non-positive step would never reach y and loop forever.
        raise ValueError("jump must be positive")
    while x < y:
        yield x
        # The increment must be inside the loop; outside it the generator
        # yields the start value forever.
        x += jump

def cube2xyz(file_name):
    """ Parse a cube file into its origin, grid values and axis records.

    Lines 1-2 (comments) are ignored, and the atom records that follow the
    three axis lines are skipped.

    Args:
        file_name: Path of the cube file to read.

    Returns:
        Tuple ``(origin, values, spacing_vec)``:
            origin: ``[x, y, z]`` floats taken from line 3.
            values: Every number after the atom records, as floats, in file
                order.
            spacing_vec: Lines 4-6 split into string tokens, i.e.
                ``[count, vx, vy, vz]`` per grid axis.

    Raises:
        ValueError: If line 3 is missing or not of the form
            "natoms x y z", if the atom count is negative (not supported),
            or if a grid value is not numeric.
    """
    with open(file_name, 'r') as cube_file:
        lines = cube_file.readlines()

    try:
        fields = lines[2].split()
        nat = int(fields[0])
        origin = [float(fields[1]), float(fields[2]), float(fields[3])]
    except (IndexError, ValueError) as err:
        # Fail here with a clear error rather than printing and crashing
        # later on an undefined atom count / origin.
        raise ValueError(f"ERROR: non recognized cube format: {file_name!r}") from err
    if nat < 0:
        # Reject instead of mis-reading atom records as grid values.
        raise ValueError(f"ERROR: negative atom count not supported: {file_name!r}")

    spacing_vec = [line.split() for line in lines[3:6]]
    # Grid data starts right after the `nat` atom records.
    values = [float(token) for line in lines[6 + nat:] for token in line.split()]
    return origin, values, spacing_vec

def save2dat(dat_name, spacing_vec, origin, values):
    """ Write cube grid values as an "X Y Z FES" table.

    Coordinates are scaled by ``BOHR_TO_A``, i.e. the origin and axis vectors
    are taken to be in Bohr.

    Args:
        dat_name: Destination path; an existing file is overwritten.
        spacing_vec: The three grid-axis records, each ``[count, vx, vy, vz]``
            (strings or numbers), as returned by ``cube2xyz``.
        origin: Grid origin ``[x, y, z]``.
        values: Flat grid values with the third axis varying fastest; must
            hold at least the product of the three counts.

    Returns:
        1 on success.
    """
    n_x, n_y, n_z = (int(axis[0]) for axis in spacing_vec[:3])
    # Parse the full voxel vectors once. Using all three components places
    # points correctly on non-orthogonal grids. On orthogonal grids the
    # off-diagonal terms are zero, giving the diagonal-only positions.
    (ax, ay, az), (bx, by, bz), (cx, cy, cz) = (
        [float(c) for c in axis[1:4]] for axis in spacing_vec[:3])
    idx = 0
    with open(dat_name, 'w') as file:
        file.write("#! \tX\tY\tZ\tFES\n")
        for i in range(n_x):
            for j in range(n_y):
                for k in range(n_z):
                    x = origin[0] + i * ax + j * bx + k * cx
                    y = origin[1] + i * ay + j * by + k * cy
                    z = origin[2] + i * az + j * bz + k * cz
                    file.write(f"{x*BOHR_TO_A:.6g}\t{y*BOHR_TO_A:.6g}\t{z*BOHR_TO_A:.6g}\t{values[idx]:.6g}\n")
                    idx += 1
    return 1